#!/usr/bin/python3

# This script uses tapset/hit-count.stp to profile a specific process
# or the kernel. It may take a context width, module path, pid, cmd, and timeout.
# It generates folders based on buildid, containing subdirectories
# leading to sourcefiles where one may read how many times the pc
# was at a certain line in that sourcefile.


import argparse
import sys
import os
import re
import subprocess
import tempfile
from collections import defaultdict

# Command-line interface.  -x/--pid and -c/--cmd are alternative ways of
# choosing the profiling target, so argparse rejects giving both.
parser = argparse.ArgumentParser()
pid_cmd_group = parser.add_mutually_exclusive_group()
pid_cmd_group.add_argument(
    "-x", "--pid",
    type=int,
    help='PID for systemtap to target.',
)
pid_cmd_group.add_argument(
    "-c", "--cmd",
    type=str,
    help='Command for systemtap to target.',
)
parser.add_argument(
    '-d',
    metavar="BINARY",
    type=str,
    action='append',
    default=[],
    help='Add symbol information for given binary and its shared libraries.',
)
parser.add_argument(
    "-e", "--events",
    type=str,
    default='timer.profile',
    help='Override the list of profiling probe points.',
)
parser.add_argument(
    "-T", "--timeout",
    type=int,
    help="Exit in 'timeout' seconds.",
)
parser.add_argument(
    "-p", "--print",
    action='store_true',
    help="Print annotated source files to stdout instead of files.",
)
parser.add_argument(
    "-w", "--context-width",
    metavar="WIDTH",
    type=int,
    default=-1,
    help='Limit number of lines of context around each hit.  Defaults to unlimited.',
)
parser.add_argument(
    "-s", "--stap",
    metavar="PATH",
    type=str,
    help='Override the path to the stap interpreter.',
)
parser.add_argument(
    "-v", "--verbose",
    action='count',
    default=0,
    help="Increase verbosity.",
)

# Module-wide settings read by the rest of the script.
args = parser.parse_args()
verbosity = args.verbose                      # number of -v flags given
DB_URLS = os.environ.get("DEBUGINFOD_URLS")   # None when the variable is unset

def vprint(level,*args):
    """Print *args* via print() only when the global verbosity is at least *level*."""
    if verbosity < level:
        return
    print(*args)


# SystemTap script handed to stap with -e.  The probe point list is spliced in
# from --events (args.events).  What the script text below does:
#   * on each profiling event, a sample taken outside user mode is tallied in
#     `kernel` and skipped;
#   * otherwise, when target()==0 or the current pid is in the target set, the
#     sample is tallied in `count` keyed by the (umodbuildid(), umodaddr())
#     pair and in `user`; any error raised inside the try block is tallied in
#     `unknown` instead;
#   * every second and at end, the contents of `count` are printed as
#     "buildid addr hits" lines between a BEGIN line and an END line, and
#     `count` is cleared;
#   * at end/error, summary lines with the user/kernel/unknown totals are
#     printed.
# The BEGIN/END framing and the three-field record format are what the parsing
# loop in __main__ consumes.
stap_script="""
global count
global unknown
global kernel
global user
probe begin {
  system(\"echo Starting stap data collector.\") # sent to stdout of stap-profile-annotate process
}
probe """ + args.events + """ {
  if (! user_mode()) {
    kernel <<< 1
    next
  }
  try {
    if (target()==0 || target_set_pid(pid()))
      {
        buildid = umodbuildid(uaddr());
        addr= umodaddr(uaddr());
        count[buildid,addr] <<< 1;
        user <<< 1
      }
  }
  catch /*(e)*/ { unknown <<< 1 /* printf ("%s", e) */ }
}

probe timer.s(1),end
{
  println (\"BEGIN\");
  foreach ( [buildid, addr] in count)
    {
      c = @count(count[buildid,addr]);
      println(buildid, " " , addr, " ", c);
    }
  println (\"END\");
  delete count
}
probe end,error
{
  printf (\"Counted %d known userspace hits.\\n\", @count(user))
  if (@count(kernel))
    printf (\"Ignored %d kernel hits.\\n\", @count(kernel))
  if (@count(unknown))
    printf (\"Ignored %d unknown userspace hits.\\n\", @count(unknown))
  println(\"Stopped stap data collector.\")
}
"""

# Per-buildid profile
class BuildIDProfile:
    """Hit counts for one build-id, keyed by address, and their source mapping.

    Attributes:
        counts: maps address (int, as reported by the stap script) to hits.
        buildid: build-id hex string.
        filename: name of the scratch file (under /tmp/) that carries the
            addresses handed to eu-addr2line.
        sources: maps source path -> SourceLineProfile, filled by get_sources().
    """

    def __init__(self,buildid):
        self.counts = defaultdict(int)
        self.buildid = buildid
        self.filename = self.buildid + 'addrs.txt'
        self.sources = {}

    def __str__(self):
        return "BuildIDProfile(buildid %s) items: %s sources: %s" % (self.buildid, self.counts.items(), self.sources.items())

    # Build the 'counts' dict by adding the hit count to its associated address
    def accumulate(self,pc,count):
        """Add *count* hits to address *pc*."""
        self.counts[pc] += count

    @staticmethod
    def _parse_location(line):
        """Parse one eu-addr2line output line into (source, 0-based line index).

        Returns None when the line has no ':' separated line field or that
        field contains no number, so malformed output is skipped rather than
        raising.
        """
        fields = line.split(':')
        if len(fields) < 2:
            return None
        # Sometimes addr2line responds with "linenum (discriminator num)";
        # keep only the leading number.
        m = re.search('[0-9]+', fields[1])
        if m is None:
            return None
        # eu-addr2line numbers lines from 1, whereas SourceLineProfile.report
        # indexes lines from 0, so remove the offset here.
        return fields[0], int(m.group(0)) - 1

    # Find the sources of the relative addresses in self.counts
    def get_sources(self):
        """Map every counted address to a source line and fill self.sources.

        The addresses are written, one hex value per line, to the scratch file
        /tmp/<self.filename>; get_debuginfo() turns that into one location
        line per address, which is paired positionally with the addresses.
        The scratch file is removed even if a step fails.
        """
        vprint(1,"Computing addr2line for %s" % (self.buildid))
        # Fixed ordering so output line N corresponds to address N
        ordered_keys = list(self.counts)
        addr_path = '/tmp/' + self.filename
        try:
            with open(addr_path, 'w') as f:
                f.writelines(hex(k) + '\n' for k in ordered_keys)
            vprint(2,"Dumped addresses")
            # Get source:linenum info
            dbginfo = self.get_debuginfo()
        finally:
            try:
                os.remove(addr_path)
            except OSError:
                pass

        # zip() stops at the shorter sequence, so surplus output lines (such
        # as the trailing empty one) can never index past the address list.
        for key, line in zip(ordered_keys, dbginfo.split('\n')):
            location = self._parse_location(line)
            if location is None:
                continue
            src, line_index = location
            if src not in self.sources:
                self.sources[src] = SourceLineProfile(self.buildid, src)
            self.sources[src].accumulate(line_index, self.counts[key])
        vprint(2,"Mapped to %d source files" % (len(self.sources),))

    # Report information for this buildid's source files
    def report(self,totalhits):
        """Report every source file of this build-id against *totalhits*."""
        for so in self.sources.values():
            so.report(totalhits)

    # Get source:linenum information. Assumes /tmp/<self.filename> holds the addresses
    def get_debuginfo(self):
        """Return eu-addr2line output for the addresses in the scratch file.

        Fetches the debuginfo file for this build-id with debuginfod-find and
        feeds the scratch file to eu-addr2line on stdin.  Any failure is
        printed and results in an empty string, so callers simply see no
        locations.
        """
        try:
            # Get the debuginfo of the buildid retrieved from stap
            p = subprocess.Popen(['debuginfod-find', 'debuginfo', self.buildid],stdout=subprocess.PIPE)
            dbg_file, _ = p.communicate()
            dbg_file = dbg_file.decode('utf-8').rstrip()
            if not dbg_file:
                raise Exception("No debug file for bid %s from debuginfod servers: %s" % (self.buildid, DB_URLS))
            vprint(2, "Stored debuginfod-find debuginfo file as %s" % (dbg_file))
            # Run eu-addr2line directly (no shell), with the address file as stdin
            with open('/tmp/' + self.filename, 'rb') as addrs:
                process = subprocess.Popen(['eu-addr2line', '-A', '-e', dbg_file], stdin=addrs, stdout=subprocess.PIPE)
                out, _ = process.communicate()
            return out.decode('utf-8')
        except Exception as e:
            print (e)
            return ''


# Contains information related to each source of a buildid
class SourceLineProfile:
    """Per-line hit counts for one source file of one build-id.

    Attributes:
        bid: build-id hex string the source file belongs to.
        source: source path as reported for the hits ('??' or '' if unknown).
        counts: maps 0-based line index -> accumulated hit count.
    """

    class _Blob:
        """A contiguous range of lines to print, with the hit lines inside it."""

        def __init__(self, lower, upper, hit):
            self.lower = lower
            self.upper = upper
            self.hits = [hit]

        def __str__(self):
            hits = ', '.join(str(i) for i in self.hits)
            if self.lower != self.upper:
                return ("Hits: " + hits + ". Context from lines %s to %s") % (self.lower, self.upper)
            return ("Hits: " + hits + ". Context of line %s") % (self.upper)

        def get_context(self):
            """Return the header line written before this blob's first line."""
            return "//" + str(self) + "\n"

    def __init__(self,  bid, source):
        self.bid = bid
        self.source = source
        self.counts = defaultdict(int)

    def __str__(self):
        return "SourceLineProfile(bid %s, source %s) counts: %s" % (self.bid, self.source, self.counts.items())

    # Accumulate hits on a line
    def accumulate(self, line, count):
        """Add *count* hits to 0-based line index *line*."""
        self.counts[line] += count

    # Get the source file associated with a buildid
    def get_source_file(self):
        """Return the local path debuginfod-find prints for this source.

        Returns None (after printing the reason) when debuginfod-find cannot
        be run or prints no path.
        """
        try:
            p = subprocess.Popen(['debuginfod-find', 'source', self.bid, self.source],stdout=subprocess.PIPE)
            sourcefile, _ = p.communicate()
            sourcefile = sourcefile.decode('utf-8').rstrip()
            if not sourcefile:
                raise Exception("No source file for bid %s, source %s from debuginfod servers: %s" % (self.bid, self.source, DB_URLS))
            vprint(2, "Stored debuginfod-find source file as %s" % (sourcefile))
            return sourcefile
        except Exception as e:
            print (e)
            return None

    def _output_path(self, sourcefile):
        """Return the path under profile-<bid>/ to write the annotated file to.

        The last component of *sourcefile* has its '##' separators turned
        back into '/'.  If the resulting path would climb out of the
        profile-<bid> directory a warning is printed; in every case each
        "/.." is rewritten to "/dotdot" so the path stays below that directory.
        """
        outfile = os.path.join('profile-'+self.bid, (sourcefile.split('/')[-1]).replace('##','/'))
        # Starts at -1 so that the profile-<bid> component itself brings the
        # depth to 0; a negative depth afterwards means we left that directory.
        depth = -1
        for word in outfile.split('/'):
            depth += -1 if word == ".." else 1
            if depth < 0:
                print(outfile + " descends beyond its intended root directory, profile-"+self.bid+".\nEnsuring the directory remains above profile-"+self.bid+" ... ")
                break
        return re.sub(r"/\.\.", "/dotdot", outfile)

    @staticmethod
    def _context_blobs(hitlines, width, num_lines):
        """Group sorted *hitlines* into the line ranges that should be printed.

        *width* is the number of context lines on each side of a hit, or None
        for unlimited context (one blob spanning the whole file, with lower
        bound -1 so that no context header is emitted for it).  *num_lines* is
        the index of the last line of the file.  Ranges that touch or overlap
        are merged.
        """
        if not hitlines:
            return []
        Blob = SourceLineProfile._Blob
        if width is None:
            whole = Blob(-1, num_lines, hitlines[0])
            whole.hits = list(hitlines)
            return [whole]
        # Clamp the first lower bound to 0 so its header line can be emitted
        blobs = [Blob(max(0, hitlines[0] - width), hitlines[0] + width, hitlines[0])]
        for i in hitlines[1:]:
            lower = i - width
            upper = i + width
            # - 1 to connect blobs bordering one another
            if blobs[-1].upper >= lower - 1:
                blobs[-1].upper = upper
                blobs[-1].hits.append(i)
            else:
                blobs.append(Blob(lower, upper, i))
        blobs[-1].upper = min(blobs[-1].upper, num_lines)
        return blobs

    def _write_annotated(self, source_lines, blobs, of):
        """Write the lines covered by *blobs* to *of*, prefixing hit counts."""
        pending = list(blobs)
        for linenum, line in enumerate(source_lines):
            # Drop blobs whose context we have passed
            while pending and pending[0].upper < linenum:
                pending.pop(0)
            if not pending:
                break
            current = pending[0]
            # Header at the beginning of a blob's context
            if linenum == current.lower:
                of.write(current.get_context())
            if linenum in self.counts:
                of.write("%07d %s\n" % (self.counts[linenum], line.rstrip()))
            elif current.lower <= linenum <= current.upper:
                of.write("%7s %s\n" % ("", line.rstrip()))

    # Reporting function for the source file
    def report(self, totalhits):
        """Print a summary line and emit the annotated source file.

        The annotated source goes to stdout when --print was given, otherwise
        to a file below profile-<bid>/.  Lines with hits are prefixed with
        their hit count; other lines are emitted only when they fall within
        --context-width lines of a hit (all lines when the width is
        unlimited).  I/O errors are printed and end the report for this file
        only.
        """
        filehits = sum(self.counts.values())
        if self.source == '??' or self.source == '':
            vprint(0,"%08d (%.2f%%) hits in buildid %s with unknown source" % (filehits, filehits/totalhits*100,
                                                                               self.bid))
            return
        if not self.counts:
            return
        # Retrieve the sourcefile's name
        sourcefile = self.get_source_file()
        if not sourcefile:
            return

        outfile = self._output_path(sourcefile)
        vprint(0,"%07d (%.2f%%) hits in %s over %d lines." % (filehits, filehits/totalhits*100,
                                                             outfile, len(self.counts)))
        try:
            # errors='replace' keeps sources in legacy encodings readable
            with open(sourcefile, 'r', errors='replace') as f:
                source_lines = f.readlines()
            width = args.context_width if args.context_width >= 0 else None
            blobs = self._context_blobs(sorted(self.counts), width, len(source_lines) - 1)
            if args.print:
                # stdout is shared with the rest of the script; do not close it
                self._write_annotated(source_lines, blobs, sys.stdout)
            else:
                os.makedirs(os.path.dirname(outfile), exist_ok=True)
                with open(outfile, 'w') as of:
                    self._write_annotated(source_lines, blobs, of)
        except OSError as e:
            print(e)

def __main__():
    """Run the stap data collector, parse its output and report per source file.

    Raises:
        Exception: when DEBUGINFOD_URLS is unset or an option value is out of
            range.
    """
    # We require $DEBUGINFOD_URLS
    if (not DB_URLS):
        raise Exception("Required DEBUGINFOD_URLS is unset.")

    # Validate options before creating any temporary file
    if args.timeout and args.timeout < 0:
        raise Exception("Timeout must be positive")
    if args.pid and args.pid < 0:
        raise Exception("pid must be positive")
    if args.context_width and args.context_width < -1:
        raise Exception("context_width must be positive or -1 (for all file)")

    stap_cmd = "@prefix@/bin/stap"  # not @ bindir @ because autoconf expands that to shell $var expressions
    if args.stap:
        stap_cmd = args.stap

    buildids = {} # dict from buildid hexcode to BuildIdProfile object
    proflines = 0
    totalhits = 0

    # Run SystemTap
    (tmpfd,tmpfilename) = tempfile.mkstemp()
    os.close(tmpfd) # stap writes to the file by name (-o); we reopen it below
    try:
        stap_args = ['--ldd', '-o'+tmpfilename]
        if args.cmd:
            stap_args += ['-c', args.cmd]
        if args.timeout:
            stap_args += ['-T', str(args.timeout)]
        if args.pid:
            stap_args += ['-x', str(args.pid)]
        for d in args.d:
            stap_args += ['-d', d]
        stap_args += ['-e', stap_script]

        vprint(1,"Building stap data collector.")
        vprint(2,"%s %s" % (stap_cmd, stap_args))

        p = subprocess.Popen([stap_cmd] + stap_args)
        try:
            p.communicate() # wait until process exits
        except KeyboardInterrupt:
            pass
        p.kill()
        p.wait() # reap the child

        outp_begin = False
        with open(tmpfilename,"r") as stap_output: # read stap output, text mode
            for line in stap_output:
                line = line.rstrip()
                # All relevant output is after BEGIN and before END
                if "BEGIN" in line:
                    outp_begin = True
                elif "END" in line:
                    outp_begin = False
                elif not outp_begin:
                    if line != "": # diagnostic message
                        vprint(0,line)
                else: # an actual profile record
                    proflines += 1
                    try:
                        (buildid,pc,hits) = line.split()
                        vprint(3,"(%s,%s,%s)" % (buildid,pc,hits))
                        # Convert both numbers before touching any state so a
                        # malformed record leaves nothing half-recorded
                        pc = int(pc)
                        hits = int(hits)
                    except ValueError as e: # parse error?
                        vprint(2,e)
                        continue
                    totalhits += hits
                    bidp = buildids.get(buildid)
                    if bidp is None:
                        bidp = buildids[buildid] = BuildIDProfile(buildid)
                    # Accumulate hits for offset pc
                    bidp.accumulate(pc,hits)
    finally:
        os.remove(tmpfilename)

    vprint(0, "Consumed %d profile records of %d hits across %d buildids." % (proflines, totalhits, len(buildids)))

    # Output source information for each buildid
    totalhits = sum(sum(bid.counts.values()) for bid in buildids.values())
    for bidp in buildids.values():
        bidp.get_sources()
        bidp.report(totalhits)

if __name__ == '__main__':
    __main__()
